import argparse
import json
import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import pdb

def format_question(example):
    """Build a zero-shot prompt for a single multiple-choice question.

    Args:
        example: Mapping with a ``"question"`` entry and a ``"choices"``
            mapping whose ``"text"`` entry is the ordered list of options.

    Returns:
        The prompt string: an instruction line, a blank line, the question,
        a ``Choices:`` header, one lettered line per option (A, B, C, ...)
        and a trailing ``Answer:`` cue for the model to complete.
    """
    choices = example["choices"]["text"]
    # Letter the options from the list itself instead of indexing [0]..[4],
    # so a question with fewer than five options no longer raises IndexError.
    lettered = "\n".join(
        f"{chr(ord('A') + i)}. {choice}" for i, choice in enumerate(choices)
    )
    return (
        "Answer the following multiple-choice question by selecting the most "
        "suitable option (A, B, C, D, or E).\n"
        "\n"
        f"Q: {example['question']}\n"
        "Choices:\n"
        f"{lettered}\n"
        "Answer:"
    )


def add_icl_examples_to_prompt_with_last_question(indices, dataset, last_question_dict):
    """Assemble a few-shot prompt: instructions, solved examples, final question.

    Args:
        indices: Iterable of positions into ``dataset`` selecting the solved
            demonstrations, in the order they should appear.
        dataset: Indexable collection of records, each with ``"question"``,
            ``"choices"["text"]`` (at least five entries) and ``"answerKey"``.
        last_question_dict: Record of the same shape for the question to be
            answered; its answer key is not read.

    Returns:
        The instruction sentence followed by the demonstrations (each ending
        in ``Answer: <key>``) and the final question (ending in a bare
        ``Answer:``), with blank lines between the question blocks.
    """

    def render(question, option_texts, answer_tail):
        # One question block; exactly five options are shown, lettered A-E.
        rows = [f"Q: {question}", "Choices:"]
        rows += [f"{letter}. {option_texts[pos]}" for pos, letter in enumerate("ABCDE")]
        rows.append(f"Answer:{answer_tail}")
        return "\n".join(rows)

    demonstrations = []
    for position in indices:
        record = dataset[position]
        demonstrations.append(
            render(
                record["question"],
                record["choices"]["text"],
                f" {record['answerKey']}",
            )
        )

    # The question to answer gets no answer after the cue.
    final_block = render(
        last_question_dict["question"],
        last_question_dict["choices"]["text"],
        "",
    )

    instructions = "Below are multiple-choice questions with five answer choices each. For each question, select the most appropriate answer (A, B, C, D, or E). Learn from the examples and answer the final question in the same format. \n"
    solved = "\n\n".join(demonstrations)
    return f"{instructions}{solved}\n\n{final_block}"


def _parse_predicted_letter(generated_text):
    """Return the upper-cased first token of the model's continuation, or "".

    Surrounding punctuation is stripped so continuations such as ``"A."`` or
    ``"(b)"`` compare equal to a bare answer key.
    """
    tokens = generated_text.split()
    if not tokens:
        return ""
    return tokens[0].strip(".,:;()[]\"'").upper()


def main(args):
    """Evaluate a causal LM on a multiple-choice task and print its accuracy.

    Args:
        args: Parsed command-line namespace providing ``task_name``,
            ``eval_split``, ``ice_split``, ``ice_example_file``,
            ``model_path``, ``k_shot``, ``no_ice_example`` and
            ``run_for_n_samples``.

    For each evaluated example a prompt is built (few-shot from the retrieved
    indices in ``ice_example_file`` unless ``no_ice_example`` is set), up to
    10 new tokens are generated greedily, and the first generated token is
    compared with the example's ``answerKey``.
    """
    from src.dataset_readers.dataset_wrappers import get_dataset_wrapper

    dataset_wrapper = get_dataset_wrapper(args.task_name)
    print(f"Loading dataset {args.task_name} from {dataset_wrapper.hf_dataset}")

    dataset = load_dataset(dataset_wrapper.hf_dataset, split=args.eval_split)
    ice_dataset = load_dataset(dataset_wrapper.hf_dataset, split=args.ice_split)

    # Entry i is expected to hold, under 'ctxs', the ice_dataset indices to
    # use as demonstrations for evaluation example i.
    with open(args.ice_example_file, "r", encoding="utf-8") as f:
        ice_examples = json.load(f)

    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    # NOTE: a BitsAndBytesConfig used to be constructed here but was never
    # passed to from_pretrained, so the model has always been loaded in fp16.
    # The dead config is removed; loading behavior is unchanged.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    model.eval()

    correct = 0
    total = 0
    k_shot = args.k_shot

    # A non-positive request means "evaluate everything"; a request larger
    # than the split is clamped so dataset.select() cannot index out of range.
    requested = args.run_for_n_samples
    num_samples = min(requested, len(dataset)) if requested > 0 else len(dataset)

    for i, ex in tqdm(enumerate(dataset.select(range(num_samples))), total=num_samples):
        if not args.no_ice_example:
            indices = ice_examples[i]["ctxs"]
            prompt = add_icl_examples_to_prompt_with_last_question(
                indices=indices[:k_shot],
                dataset=ice_dataset,
                last_question_dict=ex,
            )
        else:
            prompt = format_question(ex)

        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=10,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )

        # generate() returns prompt + continuation for causal LMs; decode only
        # the continuation so a later "Answer:" emitted by the model (e.g. when
        # it starts a new question) cannot hijack the parse.
        prompt_len = inputs["input_ids"].shape[1]
        generated = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        pred = _parse_predicted_letter(generated)

        if pred == ex["answerKey"]:
            correct += 1
        total += 1

    # Guard against an empty evaluation split.
    accuracy = correct / total if total else 0.0
    print(f"Accuracy on {args.task_name}: {accuracy:.2%}")

if __name__ == "__main__":
    # Command-line entry point: declare the flags, parse them, run main().
    cli = argparse.ArgumentParser()

    # Optional string flags with defaults: task and dataset splits.
    for flag, fallback in (
        ("--task_name", "cmsqa"),
        ("--eval_split", "validation"),
        ("--ice_split", "train"),
    ):
        cli.add_argument(flag, type=str, default=fallback)

    # Mandatory paths: the JSON file of example indices and the model.
    for flag in ("--ice_example_file", "--model_path"):
        cli.add_argument(flag, type=str, required=True)

    # Prompt-size and run-length controls.
    cli.add_argument("--k_shot", type=int, default=5)
    cli.add_argument(
        "--no_ice_example",
        action="store_true",
        help="Use this flag to disable ICE examples",
    )
    cli.add_argument(
        "--run_for_n_samples",
        type=int,
        default=0,
        help="Number of examples to run for",
    )

    main(cli.parse_args())
